Skip to content
162 changes: 160 additions & 2 deletions src/move_selection/mcts.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,128 @@
from abc import ABC, abstractmethod
from time import time
from multiprocessing import Process, Pipe, Pool
import numpy as np
from multiprocessing import Process, Pipe, Pool
from time import time, sleep


class AsyncMCTSRoot:
"""
Implementation of Monte Carlo Tree Search that uses the other player's time to continue thinking.
This is achieved using multiprocessing, and a Pipe for transferring data to and from the worker process.
The parallelization is done by doing several distinct MCTS searches completely in parallel, and aggregating the
results when a move is requested.

UI Pipe protocol: send user's move, send dummy message to request move, receive move, repeat
Worker Pipe protocol: send root node, send dummy message to request results, receive results, repeat
"""
def __init__(self, GameClass, position, time_limit=5, networks=None, c=np.sqrt(2), d=1, threads=1):
self.GameClass = GameClass
if networks is None:
self.root = RolloutNode(position, None, GameClass, c, rollout_batch_size=1, pool=None)
else:
self.root = HeuristicNode(position, None, GameClass, )
parent_worker_pipes, child_worker_pipes = zip(*[Pipe() for _ in range(threads)])

if networks is not None and threads != 1:
if threads != 1:
raise Exception()

self.worker_processes = [Process(target=self.worker_loop_func,
args=(worker_pipe, time_limit)) for worker_pipe in child_worker_pipes]

def start(self):
for worker_process in self.worker_processes:
worker_process.start()

def choose_move(self, user_chosen_position):
"""
Instructs the worker thread that the user has chosen the move specified by the given position.
The worker thread will then continue thinking for time_limit, and then return its chosen move.

:param user_chosen_position: The board position resulting from the user's move.
:return: The move chosen by monte carlo tree search.
"""
self.ui_pipe.send(user_chosen_position)
return self.ui_pipe.recv()

def terminate(self):
for worker_process in self.worker_processes:
worker_process.terminate()

for worker_process in self.worker_processes:
worker_process.join()

@staticmethod
def manager_loop_func(ui_pipe, worker_pipes, GameClass, position):
root = Node(position, None, GameClass)

while not GameClass.is_over(root.position):
for worker_pipe in worker_pipes:
worker_pipe.send(root)

user_move_index = ui_pipe.recv()
root.ensure_children()
root = root.children[user_move_index]
root.parent = None

if GameClass.is_over(root.position):
break
for worker_pipe in worker_pipes:
worker_pipe.send(user_move_index)

ui_pipe.recv() # dummy message
for worker_pipe in worker_pipes:
worker_pipe.send(None)

clones = [worker_pipe.recv() for worker_pipe in worker_pipes]
root.merge(clones)
root = root.choose_best_node()
ui_pipe.send(root.position)
root.parent = None

print('Game Over in Async MCTS Root: ', GameClass.get_winner(root.position))

@staticmethod
def worker_loop_func(worker_pipe, network):
"""
Worker thread workflow: receive MCTS root node, do MCTS until polled,
return root result, wait for new MCTS root node.
"""
if network is not None:
network.initialize()

while True:
root = worker_pipe.recv()

while (not worker_pipe.poll()) or root.children is None:
best_child = root.choose_expansion_node()
if best_child is None:
break

best_child.expand()

user_move_index = worker_pipe.recv()
root = root.children[user_move_index]
root.parent = None

# should in theory check to ensure that game is not over,
# but manager_process would have done that already, so its redundant

while not worker_pipe.poll():
best_child = root.choose_expansion_node()
if best_child is None:
break

best_child.expand()
worker_pipe.recv() # flush dummy message indicated that results should be returned
worker_pipe.send(root)


class AsyncMCTS:
"""
Implementation of Monte Carlo Tree Search that uses the other player's time to continue thinking.
This is achieved using multiprocessing, and a Pipe for transferring data to and from the worker process.
The parallelization of MCTS is done by doing several rollouts in parallel each time a rollout is requested,
waiting for all of them to finish, and aggregating the result.
"""

def __init__(self, GameClass, position, time_limit=3, network=None, c=np.sqrt(2), d=1, threads=1):
Expand Down Expand Up @@ -48,6 +163,7 @@ def loop_func(GameClass, position, time_limit, network, c, d, threads, worker_pi
root = RolloutNode(position, parent=None, GameClass=GameClass, c=c, rollout_batch_size=threads, pool=pool,
verbose=True)
else:
pool = None
network.initialize()
root = HeuristicNode(position, None, GameClass, network, c, d, verbose=True)

Expand Down Expand Up @@ -93,6 +209,10 @@ def loop_func(GameClass, position, time_limit, network, c, d, threads, worker_pi
print('Game Over in Async MCTS: ', GameClass.get_winner(root.position))
break

if pool is not None:
pool.close()
pool.join()


class MCTS:
"""
Expand Down Expand Up @@ -289,6 +409,27 @@ def depth_to_end_game(self):
return 1 + max(child.depth_to_end_game() for child in self.children
if child.fully_expanded and child.get_evaluation() == self.get_evaluation())

@abstractmethod
def merge(self, clones):
pass

def merge_children_clones(self, clones):
if self.children is None:
for i, clone in enumerate(clones):
if clone.children is not None:
# copy the clone's children to self's children, and remove it from list of clones
self.children = clone.children
clones = clones[i + 1:]
break
else:
# neither self nor any clones have children so we're done
return

# self has children (either because it originally had them,
# or they were copied from the first clone that had them)
for i, child in enumerate(self.children):
child.merge([clone.children[i] for clone in clones if clone.children is not None])


class RolloutNode(AbstractNode):
def __init__(self, position, parent, GameClass, c=np.sqrt(2), rollout_batch_size=1, pool=None, verbose=False):
Expand Down Expand Up @@ -345,6 +486,15 @@ def execute_single_rollout(self):

return self.GameClass.get_winner(state)

def merge(self, clones):
self.rollout_count += sum([clone.rollout_count - self.rollout_count
for clone in clones if clone.rollout_count > self.rollout_count])
self.rollout_sum += sum([clone.rollout_sum - self.rollout_sum
for clone in clones if clone.rollout_count > self.rollout_count])

# root nodes are cloned first, and changes propagate downwards to leafs
self.merge_children_clones(clones)


class HeuristicNode(AbstractNode):
def __init__(self, position, parent, GameClass, network, c=np.sqrt(2), d=1, network_call_results=None,
Expand Down Expand Up @@ -421,3 +571,11 @@ def ensure_children(self, moves=None, network_call_results=None):
network_call_results=network_call_result, verbose=self.verbose)
for move, network_call_result in zip(moves, network_call_results)]
self.expansions = 1

def merge(self, clones):
# leaf nodes are cloned first, and changes propagate upwards to root
self.merge_children_clones(clones)

self.heuristic = max([child.heuristic for child in self.children]) if self.is_maximizing else \
min([child.heuristic for child in self.children])
self.expansions = 1 + sum([child.expansions for child in self.children])