# train_staged.py (forked from xavireig/NeuralChessAI)
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import gc
import datetime
import numpy as np
import h5py
from keras import callbacks
from keras.layers import Conv2D, Dense, Flatten, Input, Activation, Dropout
from keras.models import Sequential, load_model, model_from_json
from keras.optimizers import Adam
import chess
import tensorflow as tf
from utils import chess_dict, squares
from batch_generator import batch_generator
# work around a TF 2.4 issue: enable memory growth so TensorFlow does not
# pre-allocate all GPU memory at startup
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
stages = ['early', 'mid', 'late']
variants = ['from', 'to']
network = '256-256-1024-1024'
batch_size = 1024
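# one model per (stage, variant) pair, six networks in total; the network tag
# encodes the layer widths (two 256-filter conv layers, two 1024-unit dense layers)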
for stage in stages:
    for variant in variants:
        print(f'Training network {stage}-{variant}')
        data_path = f'data/data2014-2020-staged-{stage}.h5'
        validation_data_path = f'data/data2013-staged-{stage}.h5'
        model_path = f'models/{stage}-{variant}-withTurn-b{batch_size}-{network}-model.h5'
        # model_path = 'models/to-withTurn-b2048-128-256-1024-1024-model.h5'
        model_path_json = f'models/{stage}-{variant}-withTurn-b{batch_size}-{network}-model.json'
        # log_dir = "logs/" + variant + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        log_dir = f'logs/{stage}-{variant}-b{batch_size}-{network}'
        # training data: count the samples to derive the batches per epoch
        with h5py.File(data_path, "r") as h5f:
            num_samples = len(h5f[f'moved_{variant}_{stage}'])
        steps_per_epoch = int(np.ceil(num_samples / batch_size))
        training_data_batch_generator = batch_generator(data_path, batch_size, steps_per_epoch, variant, stage)
        # validation data
        with h5py.File(validation_data_path, "r") as h5f:
            validation_num_samples = len(h5f[f'moved_{variant}_{stage}'])
        validation_steps_per_epoch = int(np.ceil(validation_num_samples / batch_size))
        validation_data_batch_generator = batch_generator(validation_data_path, batch_size, validation_steps_per_epoch, variant, stage)
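        # note (assumption): batch_generator from the local batch_generator.py
        # is expected to yield (inputs, targets) batches indefinitely, reading
        # the 'moved_<variant>_<stage>' HDF5 datasets in chunks of batch_size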
        # initialize the neural network
        model = Sequential()
        model.add(Input(shape=(9, 8, 12)))
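        # shape note ("withTurn" in the model name): presumably an 8x8 board
        # one-hot encoded over 12 piece-type channels, with an extra 9th row
        # encoding the side to move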
        model.add(Conv2D(256, kernel_size=(2, 2)))
        model.add(Activation('relu'))
        model.add(Conv2D(256, kernel_size=(2, 2)))
        model.add(Activation('relu'))
        model.add(Flatten())
        model.add(Dense(1024))
        model.add(Activation('relu'))
        model.add(Dropout(0.2))
        model.add(Dense(1024))
        model.add(Activation('relu'))
        # 64 classes as output (one-hot encoded square on the board)
        model.add(Dense(64, activation='softmax'))
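        # the 'from' variant learns the origin square of the played move, the
        # 'to' variant its destination (one class per board square)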
        # in case we need to resume training:
        # model = load_model(model_path)
        # learning-rate decay to reduce overfitting (currently unused):
        # lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(0.001, decay_steps=100000, decay_rate=0.96, staircase=True)
        model.compile(loss='categorical_crossentropy', optimizer=Adam(learning_rate=0.001), metrics=['accuracy'])
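        # categorical_crossentropy requires one-hot targets, matching the
        # 64-way one-hot square labels assumed above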
        # save the model architecture for later
        model_json = model.to_json()
        with open(model_path_json, 'w') as json_file:
            json_file.write(model_json)
        # print(model.summary())
        # checkpoint every epoch, keeping only the weights with the best
        # validation accuracy, so training can be resumed if it stops
        checkpoint = callbacks.ModelCheckpoint(model_path, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max', save_freq='epoch')
        # stop early if val_accuracy doesn't improve for 500 epochs (currently unused):
        # es = callbacks.EarlyStopping(monitor='val_accuracy', mode='max', verbose=1, patience=500)
        tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
        model.fit(training_data_batch_generator,
                  steps_per_epoch=steps_per_epoch,
                  validation_data=validation_data_batch_generator,
                  validation_steps=validation_steps_per_epoch,
                  # initial_epoch=10,  # in case we need to resume training
                  epochs=10,
                  verbose=1,
                  # shuffle is ignored when the input is a generator; any
                  # shuffling has to happen inside batch_generator itself
                  shuffle=True,
                  # workers=4,
                  # max_queue_size=8,
                  callbacks=[checkpoint, tensorboard_callback])
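        # not in the original script (assumption): six models are trained in
        # one process, so clear Keras/TF state between runs; gc is imported
        # above but never used, which suggests cleanup like this was intended
        del model
        tf.keras.backend.clear_session()
        gc.collect()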