forked from dkappe/leela_lite
-
Notifications
You must be signed in to change notification settings - Fork 0
/
engine.py
116 lines (97 loc) · 2.72 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from lcztools import load_network, LeelaBoard
import search
import chess
import chess.pgn
import sys
# Debug switch: flip to True to record the UCI conversation in leelalite.log.
LOG = False
# The log file is opened (and truncated) at start-up whether or not LOG is
# enabled; log() only writes to it while LOG is True.
logfile = open("leelalite.log", mode="w")
def log(str):
    """Append one line to the engine log file, but only while LOG is True.

    :param str: text of the line (a trailing newline is added here). The
        parameter name shadows the ``str`` builtin inside this function.
    """
    if not LOG:
        return
    logfile.write(str + "\n")
    # Flush per line so each entry reaches the file immediately.
    logfile.flush()
def send(str):
    """Emit one protocol line to the GUI on stdout, mirroring it to the log.

    The log copy is prefixed with ``>`` to mark outgoing traffic. stdout is
    flushed after every line so the GUI sees it without waiting for buffers.

    :param str: the line to send, without a trailing newline. The parameter
        name shadows the ``str`` builtin inside this function.
    """
    log(">{}".format(str))
    out = sys.stdout
    out.write(str + "\n")
    out.flush()
def process_position(tokens):
    """Build a board from a tokenised UCI ``position`` command.

    Accepted forms::

        position startpos [moves m1 m2 ...]
        position fen <fen fields ...> [moves m1 m2 ...]

    The FEN is taken as every token between ``fen`` and ``moves`` (or the end
    of the line). It is not assumed to have exactly six fields, so a FEN
    without halfmove/fullmove counters does not swallow the ``moves`` keyword.

    :param tokens: the command line split on whitespace; ``tokens[0]`` is
        ``"position"``.
    :return: a ``LeelaBoard`` with any listed moves pushed via ``push_uci``.
        A bare ``position`` or an unrecognised position type yields the
        starting position with no moves applied.
    NOTE(review): an invalid FEN or illegal move presumably raises from
    ``LeelaBoard``/``push_uci`` — confirm against lcztools.
    """
    board = LeelaBoard()
    if len(tokens) < 2:
        # Bare "position": previously an IndexError on tokens[1].
        return board

    # Everything after the first "moves" keyword is the UCI move list.
    if 'moves' in tokens:
        moves_at = tokens.index('moves')
    else:
        moves_at = len(tokens)

    if tokens[1] == 'fen':
        board = LeelaBoard(fen=" ".join(tokens[2:moves_at]))
    elif tokens[1] != 'startpos':
        # Unknown position type: keep the start position, apply no moves.
        return board

    for move in tokens[moves_at + 1:]:
        board.push_uci(move)
    return board
# --- Command-line arguments ---------------------------------------------
# engine.py <weights file or network server ID> <nodes> [minimax]
# Any optional third argument other than "minimax" selects UCT search.
if len(sys.argv) not in (3, 4):
    # stdout carries the UCI protocol, so diagnostics go to stderr.
    print("Usage: python3 engine.py <weights file or network server ID> <nodes> [minimax]",
          file=sys.stderr)
    print("Expected 2 or 3 arguments, got {}".format(len(sys.argv) - 1),
          file=sys.stderr)
    sys.exit(1)

weights = sys.argv[1]
try:
    nodes = int(sys.argv[2])
except ValueError:
    print("<nodes> must be an integer, got {!r}".format(sys.argv[2]),
          file=sys.stderr)
    sys.exit(1)

# NOTE: `type` shadows the builtin, but the UCI main loop reads this global
# by that name, so it is kept.
if len(sys.argv) == 4 and sys.argv[3] == 'minimax':
    type = 'minimax'
else:
    type = 'uct'

network_id = None
try:
    # If the parameter is an integer, assume it's a network server ID
    network_id = int(weights)
    weights = None
except ValueError:
    # Not an integer: treat it as a path to a local weights file.
    pass
def load_leela_network(force=False):
    """Load the Leela network and its search wrapper into module globals.

    Sets the global ``net`` to the result of ``load_network`` and ``nn`` to a
    ``search.NeuralNet`` wrapping it with ``lru_size=max(5000, nodes)``.
    When ``network_id`` was parsed from the command line, the ``net_client``
    backend is used. Otherwise the ``weights`` file is loaded with the
    ``pytorch_cuda`` backend.

    The network is loaded only once. Later calls return immediately unless
    *force* is true. The UCI loop calls this on every ``isready``, and GUIs
    may send that many times. Reloading each time would repeat the load and
    replace the ``NeuralNet`` wrapper along with any cache it holds.

    Assumes the module-level ``nn`` has been initialised (to ``None``) before
    the first call.

    :param force: reload even if a network is already loaded.
    """
    global net, nn
    if nn is not None and not force:
        return
    softmax_temp = 2.2  # policy softmax temperature, same for both backends
    if network_id is not None:
        net = load_network(backend='net_client', network_id=network_id,
                           policy_softmax_temp=softmax_temp)
    else:
        net = load_network(backend='pytorch_cuda', filename=weights,
                           policy_softmax_temp=softmax_temp)
    nn = search.NeuralNet(net=net, lru_size=max(5000, nodes))
# --- UCI main loop ------------------------------------------------------
send("Leela Lite")
board = LeelaBoard()
net = None  # raw network, set by load_leela_network()
nn = None   # search.NeuralNet wrapper, loaded lazily on isready/go

while True:
    line = sys.stdin.readline()
    if not line:
        # readline() returns "" only at EOF (the GUI closed our stdin).
        # Without this check the loop would spin forever on empty input.
        break
    line = line.rstrip()
    log("<{}".format(line))
    tokens = line.split()
    if not tokens:
        continue
    command = tokens[0]
    if command == "uci":
        send('id name Leela Lite')
        send('id author Dietrich Kappe')
        send('uciok')
    elif command == "quit":
        # Leave the loop so the log file is closed below; the script then
        # ends with exit status 0.
        break
    elif command == "isready":
        # Load on first use only; reloading on every isready is wasted work.
        if nn is None:
            load_leela_network()
        send("readyok")
    elif command == "ucinewgame":
        board = LeelaBoard()
    elif command == 'position':
        board = process_position(tokens)
    elif command == 'go':
        my_nodes = nodes
        # Honour "nodes N" anywhere in the go command (e.g. alongside
        # wtime/btime), not only the exact form "go nodes N".
        if 'nodes' in tokens:
            nodes_at = tokens.index('nodes')
            try:
                my_nodes = int(tokens[nodes_at + 1])
            except (IndexError, ValueError):
                log("malformed nodes limit, using default {}".format(nodes))
        if nn is None:
            load_leela_network()
        if type == 'uct':
            best, node = search.UCT_search(board, my_nodes, net=nn, C=3.0)
        else:
            best, node = search.MinMax_search(board, my_nodes, net=nn, C=3.0)
        send("bestmove {}".format(best))
logfile.close()