Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions perfect_information_game/move_selection/mcts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from perfect_information_game.move_selection.mcts.abstract_node import AbstractNode
from perfect_information_game.move_selection.mcts.rollout_node import RolloutNode
from perfect_information_game.move_selection.mcts.heuristic_node import HeuristicNode
from perfect_information_game.move_selection.mcts.tablebase_node import TablebaseNode
from perfect_information_game.move_selection.mcts.mcts import MCTS
from perfect_information_game.move_selection.mcts.async_mcts import AsyncMCTS
98 changes: 75 additions & 23 deletions perfect_information_game/move_selection/mcts/abstract_node.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,41 @@
from abc import ABC, abstractmethod
import numpy as np
from perfect_information_game.tablebases import EmptyTablebaseManager


class AbstractNode(ABC):
def __init__(self, position, parent, GameClass, c=np.sqrt(2), verbose=False):
def __init__(self, position, parent, GameClass, tablebase_manager=None, verbose=False):
self.position = position
self.parent = parent
self.GameClass = GameClass
self.c = c
self.fully_expanded = GameClass.is_over(position)
self.tablebase_manager = tablebase_manager if tablebase_manager is not None \
else EmptyTablebaseManager(GameClass)
self.verbose = verbose

if self.GameClass.is_over(position):
self.fully_expanded = True
self.outcome = GameClass.get_winner(position)
else:
self.fully_expanded = False
self.outcome = None # this will be None if fully_expanded is False, otherwise it will be either -1, 0, or 1

self.is_maximizing = GameClass.is_player_1_turn(position)
self.children = None
self.verbose = verbose

@abstractmethod
def get_evaluation(self):
"""
Returns the prediction of the evaluation of the game assuming perfect play, which is in the range [-1, 1].
This will be 1 if player 1 is winning, 0 if it is a draw, and -1 if player 1 is losing.
"""
pass

@abstractmethod
def count_expansions(self):
"""
Returns the number of times that this node has been explored.
If the node is fully expanded, then np.inf will be returned.
"""
pass

@abstractmethod
Expand All @@ -33,11 +50,15 @@ def get_puct_heuristic_for_child(self, i):
def expand(self):
pass

@abstractmethod
def set_fully_expanded(self, minimax_evaluation):
pass
self.fully_expanded = True
self.outcome = minimax_evaluation

def choose_best_node(self, return_probability_distribution=False, optimal=False):
"""
Chooses the best move based on the expansions performed so far.
This should only need to be called on the root node in order to choose a move for the AI.
"""
distribution = []

optimal_value = 1 if self.is_maximizing else -1
Expand All @@ -55,10 +76,9 @@ def choose_best_node(self, return_probability_distribution=False, optimal=False)
if child.fully_expanded and child.get_evaluation() == self.get_evaluation():
# TODO: when losing, consider the number of ways the opponent can win in response to a move
depth_to_endgame = child.depth_to_end_game()
# if we are winning, weight smaller depths much more strongly by using e^-x
# if we are losing or drawing, weight larger depths much more strongly by using e^x
relative_probability = np.exp(-depth_to_endgame if self.get_evaluation() == optimal_value else
depth_to_endgame)
# if we want a short game, weight smaller depths much more strongly by using e^-x
# if we want a long game, weight larger depths much more strongly by using e^x
relative_probability = np.exp(depth_to_endgame if self.want_long_game() else -depth_to_endgame)
distribution.append(relative_probability)
else:
distribution.append(0)
Expand All @@ -81,12 +101,24 @@ def choose_best_node(self, return_probability_distribution=False, optimal=False)
best_child = self.children[idx]
return (best_child, distribution) if return_probability_distribution else best_child

def choose_expansion_node(self):
def choose_expansion_node(self, search_suboptimal=False):
"""
Searches the tree to find a node to expand.
Returns None if no nodes could be found because they are all fully expanded.
:param search_suboptimal: If True, then nodes will continue to be searched even if
they have siblings that lead to a win.
This should only be set to True once the entire tree has been searched and the best line has been determined.
"""
if search_suboptimal:
raise NotImplementedError()

# TODO: continue tree search in case the user makes a mistake and the game continues
if self.fully_expanded:
# self must be the root node because a fully expanded node would never be chosen by its parent
return None

if self.count_expansions() == 0:
# this node itself has never been expanded
return self

self.ensure_children()
Expand All @@ -99,11 +131,13 @@ def choose_expansion_node(self):
# If this child is already optimal, then self is fully expanded and there is no point searching further
if child.get_evaluation() == optimal_value:
self.set_fully_expanded(optimal_value)
# delegate to the parent, which will now choose differently since self is fully expanded
return self.parent.choose_expansion_node() if self.parent is not None else None
# continue searching other children, there may be another child that is more optimal
continue

# check puct heuristic before calling child.get_evaluation() because it may result in division by 0
# check puct heuristic before calling child.get_evaluation() because
# it may result in division by 0 for RolloutNode
puct_heuristic = self.get_puct_heuristic_for_child(i)
if np.isinf(puct_heuristic):
return child
Expand All @@ -122,30 +156,48 @@ def choose_expansion_node(self):
# if nothing was found because all children are fully expanded
if best_child is None:
if self.verbose and not self.fully_expanded and self.parent is None:
# print this message when the root node becomes fully expanded so that it is only printed once
print('Fully expanded tree!')

minimax_evaluation = max([child.get_evaluation() for child in self.children]) if self.is_maximizing \
else min([child.get_evaluation() for child in self.children])
minimax_evaluation = (max if self.is_maximizing else min)(
[child.get_evaluation() for child in self.children])
self.set_fully_expanded(minimax_evaluation)
# this node is now fully expanded, so ask the parent to try to choose again
# if no parent is available (i.e. this is the root node) then the entire search tree has been expanded
return self.parent.choose_expansion_node() if self.parent is not None else None

# the best child has been chosen, and the expansion node choice is delegated to it now
return best_child.choose_expansion_node()

def depth_to_end_game(self):
def want_long_game(self):
"""
Can only be called on fully expanded nodes.

Returns True if we want the game to be as long as possible and
False if we want the game to be as fast as possible.
"""
if not self.fully_expanded:
raise Exception('Node not fully expanded!')

if self.children is None:
return 0

optimal_value = 1 if self.is_maximizing else -1
if self.get_evaluation() == optimal_value:
if self.outcome == optimal_value:
# if we are winning, win as fast as possible
return 1 + min(child.depth_to_end_game() for child in self.children
if child.fully_expanded and child.get_evaluation() == self.get_evaluation())
return False
elif self.outcome == -optimal_value:
# if we are losing, lose as slow as possible
return True
else:
# if we are losing or it is a draw, lose as slow as possible
return 1 + max(child.depth_to_end_game() for child in self.children
if child.fully_expanded and child.get_evaluation() == self.get_evaluation())
try:
heuristic = self.GameClass.heuristic(self.position)
# if the game is a draw, make it as slow as possible if we have a material advantage
# otherwise, finish the game as fast as possible
return heuristic > 0 if self.is_maximizing else heuristic < 0
except NotImplementedError:
# if the game is a draw and no heuristic is defined, then finish the game as fast as possible
return False

def depth_to_end_game(self):
return 1 + (max if self.want_long_game() else min)(
child.depth_to_end_game() for child in self.children
if child.fully_expanded and child.get_evaluation() == self.get_evaluation())
136 changes: 71 additions & 65 deletions perfect_information_game/move_selection/mcts/async_mcts.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from multiprocessing import Pipe, Pool
from multiprocessing import Pipe
from multiprocessing.context import Process
from time import time
import numpy as np
from perfect_information_game.move_selection.mcts import HeuristicNode
from perfect_information_game.move_selection.mcts import RolloutNode
from perfect_information_game.move_selection import MoveChooser
from perfect_information_game.move_selection.mcts import TablebaseNode, RolloutNode, HeuristicNode
from perfect_information_game.tablebases import EmptyTablebaseManager
from perfect_information_game.utils import OptionalPool


class AsyncMCTS(MoveChooser):
Expand All @@ -13,7 +14,8 @@ class AsyncMCTS(MoveChooser):
This is achieved using multiprocessing, and a Pipe for transferring data to and from the worker process.
"""

def __init__(self, GameClass, starting_position, time_limit=3, network=None, c=np.sqrt(2), d=1, threads=1):
def __init__(self, GameClass, starting_position, time_limit=3, network=None, c=np.sqrt(2), d=1, threads=1,
tablebase_manager=None):
"""
Either:
If network is provided, threads must be 1.
Expand All @@ -23,10 +25,12 @@ def __init__(self, GameClass, starting_position, time_limit=3, network=None, c=n
if network is not None and threads != 1:
raise ValueError('Threads != 1 with Network != None')

if tablebase_manager is None:
tablebase_manager = EmptyTablebaseManager(GameClass)
self.parent_pipe, worker_pipe = Pipe()
self.worker_process = Process(target=self.loop_func,
args=(GameClass, starting_position, time_limit, network, c, d, threads,
worker_pipe))
tablebase_manager, worker_pipe))

def start(self):
self.worker_process.start()
Expand Down Expand Up @@ -60,69 +64,71 @@ def terminate(self):
self.worker_process.join()

@staticmethod
def loop_func(GameClass, position, time_limit, network, c, d, threads, worker_pipe):
if network is None:
pool = Pool(threads) if threads > 1 else None
root = RolloutNode(position, parent=None, GameClass=GameClass, c=c, rollout_batch_size=threads, pool=pool,
verbose=True)
else:
network.initialize()
root = HeuristicNode(position, None, GameClass, network, c, d, verbose=True)

while True:
best_node = root.choose_expansion_node()

if best_node is not None:
best_node.expand()

if root.children is not None and worker_pipe.poll():
user_chosen_position = worker_pipe.recv()

if user_chosen_position is not None:
# an updated position has been received so we can truncate the tree
for child in root.children:
if np.all(child.position == user_chosen_position):
root = child
root.parent = None
def loop_func(GameClass, position, time_limit, network, c, d, threads, tablebase_manager, worker_pipe):
with OptionalPool(threads) as pool:
if network is not None:
network.initialize()

root = TablebaseNode.attempt_create(position, None, GameClass, tablebase_manager, verbose=True,
backup_factory=lambda:
RolloutNode(position, None, GameClass, c, threads, pool, verbose=True)
if network is None else
HeuristicNode(position, None, GameClass, network, c, d, verbose=True))

while True:
best_node = root.choose_expansion_node()

if best_node is not None:
best_node.expand()

if root.children is not None and worker_pipe.poll():
user_chosen_position = worker_pipe.recv()

if user_chosen_position is not None:
# an updated position has been received so we can truncate the tree
for child in root.children:
if np.all(child.position == user_chosen_position):
root = child
root.parent = None
break
else:
print(user_chosen_position)
raise Exception('Invalid user chosen move!')

if GameClass.is_over(root.position):
print('Game Over in Async MCTS: ', GameClass.get_winner(root.position))
break
else:
print(user_chosen_position)
raise Exception('Invalid user chosen move!')

if GameClass.is_over(root.position):
print('Game Over in Async MCTS: ', GameClass.get_winner(root.position))
return
else:
# this move chooser has been requested to decide on a move via the choose_move function
start_time = time()
while time() - start_time < time_limit:
best_node = root.choose_expansion_node()

# best_node will be None if the tree is fully expanded
if best_node is None:
break

best_node.expand()

is_ai_player_1 = GameClass.is_player_1_turn(root.position)
chosen_positions = []
print(f'MCTS choosing move based on {root.count_expansions()} expansions!')

# choose moves as long as it is still the ai's turn
while GameClass.is_player_1_turn(root.position) == is_ai_player_1:
if root.children is None:
# this move chooser has been requested to decide on a move via the choose_move function
start_time = time()
while time() - start_time < time_limit:
best_node = root.choose_expansion_node()
if best_node is not None:
best_node.expand()
root, distribution = root.choose_best_node(return_probability_distribution=True, optimal=True)
chosen_positions.append((root.position, distribution))

print('Expected outcome: ', root.get_evaluation())
root.parent = None # delete references to the parent and siblings
worker_pipe.send(chosen_positions)
if GameClass.is_over(root.position):
print('Game Over in Async MCTS: ', GameClass.get_winner(root.position))
return

# best_node will be None if the tree is fully expanded
if best_node is None:
break

best_node.expand()

is_ai_player_1 = GameClass.is_player_1_turn(root.position)
chosen_positions = []
print(f'MCTS choosing move based on {root.count_expansions()} expansions!')

# choose moves as long as it is still the ai's turn
while GameClass.is_player_1_turn(root.position) == is_ai_player_1:
if root.children is None:
best_node = root.choose_expansion_node()
if best_node is not None:
best_node.expand()
root, distribution = root.choose_best_node(return_probability_distribution=True, optimal=True)
chosen_positions.append((root.position, distribution))

print('Expected outcome: ', root.get_evaluation())
root.parent = None # delete references to the parent and siblings
worker_pipe.send(chosen_positions)
if GameClass.is_over(root.position):
print('Game Over in Async MCTS: ', GameClass.get_winner(root.position))
break

def reset(self):
raise NotImplementedError('')
Loading