from time import sleep, time
from multiprocessing import Process, Pipe, Event


class Clock:
    """Two-player game clock whose time keeping runs in a separate process.

    Each player has an event that is set to signal "this player has moved".
    The timing process waits on the events in turn (player 1 first), charges
    the waiting time to the player to move, adds the increment, and clears
    the event to acknowledge the move. The timing process exits as soon as a
    player fails to move before their remaining time runs out.
    """

    # Seconds to sleep between polls while waiting for a move to be
    # acknowledged; avoids pegging a CPU core with a busy loop.
    _ACK_POLL_INTERVAL = 0.001

    def __init__(self, starting_time=5 * 60, increment=5):
        """
        :param starting_time: Initial time, in seconds, on each player's clock.
        :param increment: Seconds added to a player's clock after each move.
        """
        self.player_1_event = Event()
        self.player_2_event = Event()
        # Pipe(False) is one-way: the first connection can only receive and
        # the second can only send. The sending end is handed to the worker.
        self.timing_update_pipe, worker_timing_update_pipe = Pipe(False)
        self.timing_process = Process(
            target=Clock.timing_process_loop,
            args=(
                starting_time,
                increment,
                self.player_1_event,
                self.player_2_event,
                worker_timing_update_pipe,
            ),
        )

    def start(self):
        """Start the timing process; player 1's clock starts running."""
        self.timing_process.start()

    def terminate(self):
        """Forcefully stop the timing process."""
        self.timing_process.terminate()

    def player_1_move(self):
        """Signal that player 1 has moved and wait for the acknowledgement."""
        self._signal_move(self.player_1_event)

    def player_2_move(self):
        """Signal that player 2 has moved and wait for the acknowledgement."""
        self._signal_move(self.player_2_event)

    def _signal_move(self, move_event):
        """Set ``move_event`` and block until the timing process clears it.

        The wait is abandoned when the timing process is not alive (for
        example after a timeout ended its loop), because nothing would ever
        clear the event in that case and the caller would hang forever.
        """
        move_event.set()
        while move_event.is_set() and self.timing_process.is_alive():
            sleep(self._ACK_POLL_INTERVAL)

    @staticmethod
    def timing_process_loop(starting_time, increment, player_1_event, player_2_event, timing_update_pipe):
        """Run the clock until one of the players runs out of time.

        :param starting_time: Initial time, in seconds, for both players.
        :param increment: Seconds added to a player's clock after each move.
        :param player_1_event: Event set when player 1 has moved.
        :param player_2_event: Event set when player 2 has moved.
        :param timing_update_pipe: Sending end of the update pipe. It is
            accepted for the caller's benefit but not written to here.
        """
        time_left = [starting_time, starting_time]
        move_events = (player_1_event, player_2_event)

        while True:
            for player, move_event in enumerate(move_events):
                move_start_time = time()
                # Each player is timed against their OWN remaining time.
                if not move_event.wait(time_left[player]):
                    # This player timed out: stop the clock.
                    return
                time_left[player] += increment - (time() - move_start_time)
                # Clearing the event acknowledges the move to the caller.
                move_event.clear()
@staticmethod
def manage_time(root, time_remaining, p_min=0.05, p_max=0.7, p_step=0.01, f_second=0.9):
    """
    Expand the search tree under ``root`` for a share of the remaining clock time.

    At least ``p_min`` of the remaining time is spent expanding. The budget is then
    raised in steps of roughly ``p_step`` up to ``p_max``. Before each raise the search
    stops early if the root has fewer than two children, or if the second most visited
    child can no longer be expected to catch up with the most visited one.

    :param root: The search tree node to expand from. It must provide ``children``,
                 ``choose_expansion_node()`` and ``count_expansions()``.
    :param time_remaining: The total amount of time remaining on the AI's clock.
    :param p_min: The minimum fraction of the remaining time that should be used for this move.
    :param p_max: The maximum fraction of the remaining time that should be used for this move.
    :param p_step: The spacing between the successive time fractions at which the
                   early stopping condition is re-evaluated.
    :param f_second: An upper bound estimate of the fraction of expansions that will occur
                     at the second most visited child of the root node.
    """
    if time_remaining <= 0:
        # Nothing left on the clock: no expansion can be afforded, and the
        # expansion rate below would be an undefined 0 / 0.
        return

    start_time = time()
    time_elapsed = 0
    expansions = 0

    later_fractions = np.linspace(p_min + p_step, p_max,
                                  int(np.rint((p_max - p_min) / p_step)),
                                  endpoint=True)

    for i, t_fraction in enumerate(np.concatenate(([p_min], later_fractions))):
        # The first slice (p_min) is always searched. The rate estimate also
        # needs some elapsed time, which is missing if p_min is 0.
        if i != 0 and time_elapsed > 0:
            visit_counts = sorted(child.count_expansions() for child in root.children)
            if len(visit_counts) < 2:
                break

            visits_to_surpass = visit_counts[-1] - visit_counts[-2]
            expansion_rate = expansions / time_elapsed
            max_time_left = p_max * time_remaining - time_elapsed
            # Stop once the runner-up cannot close the gap even if it received
            # f_second of all expansions for the rest of the maximum budget.
            if visits_to_surpass > f_second * expansion_rate * max_time_left:
                break

        while time_elapsed < t_fraction * time_remaining:
            best_node = root.choose_expansion_node()

            # best_node will be None if the tree is fully expanded
            if best_node is None:
                return

            best_node.expand()
            expansions += 1
            time_elapsed = time() - start_time