-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpipeline.py
78 lines (73 loc) · 2.79 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env python
from alpha_net import ChessNet, train
from MCTS_chess import MCTS_self_play
import os
import pickle
import numpy as np
import torch
import torch.multiprocessing as mp
if __name__=="__main__":

    NUM_WORKERS = 6        # parallel self-play / training worker processes
    GAMES_PER_WORKER = 50  # self-play games each worker generates per iteration

    def _prepare_net(weights_file, train_mode):
        """Build a ChessNet ready for multiprocess use and load its weights.

        Moves the net to GPU when available, shares its memory so spawned
        workers see one copy, sets train/eval mode, and loads the checkpoint
        ``./model_data/<weights_file>`` (expects a ``state_dict`` key).
        """
        net = ChessNet()
        cuda = torch.cuda.is_available()
        if cuda:
            net.cuda()
        net.share_memory()
        if train_mode:
            net.train()
        else:
            net.eval()
        path = os.path.join("./model_data/", weights_file)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host;
        # the original call crashed there despite the cuda availability check.
        checkpoint = torch.load(path, map_location=None if cuda else "cpu")
        net.load_state_dict(checkpoint['state_dict'])
        return net

    def _gather_datasets(dir_paths):
        """Concatenate every pickled example list found under *dir_paths*.

        Each file is expected to hold a pickled list of training examples
        (trusted local artifacts written by the self-play phase).
        """
        examples = []
        for data_path in dir_paths:
            for file in os.listdir(data_path):
                with open(os.path.join(data_path, file), 'rb') as fo:
                    examples.extend(pickle.load(fo, encoding='bytes'))
        return examples

    # "spawn" is required so CUDA tensors survive into child processes;
    # force=True made the original's repeated per-iteration calls no-ops,
    # so calling it once up front is equivalent.
    mp.set_start_method("spawn", force=True)

    for iteration in range(10):
        # NOTE(review): the same checkpoint filename is used for every
        # iteration (play, train, and save), so each pass overwrites the
        # previous one — confirm this is intended before renaming.
        net_to_play = "current_net_trained8_iter1.pth.tar"

        # --- Self-play phase: generate games with the current net ---
        net = _prepare_net(net_to_play, train_mode=False)
        play_procs = []
        for i in range(NUM_WORKERS):
            p = mp.Process(target=MCTS_self_play, args=(net, GAMES_PER_WORKER, i))
            p.start()
            play_procs.append(p)
        for p in play_procs:
            p.join()

        # --- Training phase: fit the net on all gathered self-play data ---
        net_to_train = "current_net_trained8_iter1.pth.tar"
        save_as = "current_net_trained8_iter1.pth.tar"
        datasets = _gather_datasets(["./datasets/iter0/",
                                     "./datasets/iter1/",
                                     "./datasets/iter2/"])
        datasets = np.array(datasets)
        net = _prepare_net(net_to_train, train_mode=True)
        train_procs = []
        for i in range(NUM_WORKERS):
            p = mp.Process(target=train, args=(net, datasets, 0, 200, i))
            p.start()
            train_procs.append(p)
        for p in train_procs:
            p.join()

        # Persist the trained weights so the next iteration resumes from them.
        torch.save({'state_dict': net.state_dict()},
                   os.path.join("./model_data/", save_as))