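"""Command-line entry point for the engine.

Exposes three subcommands via argh: `gtp` (run a GTP engine with a chosen
move-selection strategy), `preprocess` (split SGF data sets into gzipped
training/test chunks), and `train` (train the policy network on those chunks).
"""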
import argparse
import os
import random
import re
import sys
import time
from contextlib import contextmanager

import argh
import gtp as gtp_lib
import tqdm

from policy import PolicyNetwork
from strategies import RandomPlayer, PolicyNetworkBestMovePlayer, PolicyNetworkRandomMovePlayer, MCTS
from load_data_sets import DataSet, parse_data_sets

TRAINING_CHUNK_RE = re.compile(r"train\d+\.chunk\.gz")

@contextmanager
def timer(message):
    """Print the wall-clock time taken by the enclosed block."""
    tick = time.time()
    yield
    tock = time.time()
    print("%s: %.3f" % (message, (tock - tick)))

def gtp(strategy, read_file=None):
    """Start a GTP engine using the given move-selection strategy."""
    n = PolicyNetwork(use_cpu=True)
    if strategy == 'random':
        instance = RandomPlayer()
    elif strategy == 'policy':
        instance = PolicyNetworkBestMovePlayer(n, read_file)
    elif strategy == 'randompolicy':
        instance = PolicyNetworkRandomMovePlayer(n, read_file)
    elif strategy == 'mcts':
        instance = MCTS(n, read_file)
    else:
        sys.stderr.write("Unknown strategy: %s\n" % strategy)
        sys.exit(1)
    gtp_engine = gtp_lib.Engine(instance)
    sys.stderr.write("GTP engine ready\n")
    sys.stderr.flush()
    while not gtp_engine.disconnect:
        inpt = input()
        # Handle either one command per line, or several commands pasted in
        # at once separated by '\n'. str.split never raises on a str, so no
        # error handling is needed here.
        cmd_list = inpt.split("\n")
        for cmd in cmd_list:
            engine_reply = gtp_engine.send(cmd)
            sys.stdout.write(engine_reply)
            sys.stdout.flush()
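
# Example `gtp` invocation (the checkpoint path is illustrative; argh maps the
# positional `strategy` parameter to a positional argument and `read_file` to
# an option):
#
#   python main.py gtp policy --read-file=saved_checkpoint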

def preprocess(*data_sets, processed_dir="processed_data"):
    """Split SGF data sets into one gzipped test chunk and many training chunks."""
    processed_dir = os.path.join(os.getcwd(), processed_dir)
    if not os.path.isdir(processed_dir):
        os.mkdir(processed_dir)

    test_chunk, training_chunks = parse_data_sets(*data_sets)
    print("Allocating %s positions as test; remainder as training" % len(test_chunk), file=sys.stderr)

    print("Writing test chunk")
    test_dataset = DataSet.from_positions_w_context(test_chunk, is_test=True)
    test_filename = os.path.join(processed_dir, "test.chunk.gz")
    test_dataset.write(test_filename)

    print("Writing training chunks")
    training_datasets = map(DataSet.from_positions_w_context, training_chunks)
    for i, train_dataset in tqdm.tqdm(enumerate(training_datasets)):
        train_filename = os.path.join(processed_dir, "train%s.chunk.gz" % i)
        train_dataset.write(train_filename)
    print("%s chunks written" % (i + 1))

def train(processed_dir, read_file=None, save_file=None, epochs=10, logdir=None, checkpoint_freq=10000):
    """Train the policy network on the chunks produced by preprocess()."""
    test_dataset = DataSet.read(os.path.join(processed_dir, "test.chunk.gz"))
    train_chunk_files = [os.path.join(processed_dir, fname)
                         for fname in os.listdir(processed_dir)
                         if TRAINING_CHUNK_RE.match(fname)]
    if read_file is not None:
        read_file = os.path.join(os.getcwd(), read_file)
    n = PolicyNetwork()
    n.initialize_variables(read_file)
    if logdir is not None:
        n.initialize_logging(logdir)

    last_save_checkpoint = 0
    for i in range(epochs):
        random.shuffle(train_chunk_files)
        for file in train_chunk_files:
            print("Using %s" % file)
            with timer("load dataset"):
                train_dataset = DataSet.read(file)
            with timer("training"):
                n.train(train_dataset)
            with timer("save model"):
                n.save_variables(save_file)
            # Evaluate on the held-out test chunk every checkpoint_freq global steps.
            if n.get_global_step() > last_save_checkpoint + checkpoint_freq:
                with timer("test set evaluation"):
                    n.check_accuracy(test_dataset)
                last_save_checkpoint = n.get_global_step()
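
# Example `train` invocation (all paths are illustrative):
#
#   python main.py train processed_data --save-file=/tmp/savedmodel --epochs=10 --logdir=logs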

parser = argparse.ArgumentParser()
argh.add_commands(parser, [gtp, preprocess, train])

if __name__ == '__main__':
    argh.dispatch(parser)