-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain.py
263 lines (226 loc) · 12.1 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
################################################################################
######### Train PSP models #########
################################################################################
#importing required modules and dependencies
import os
import argparse
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau, CSVLogger, LearningRateScheduler
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.compat.v1.keras.backend import set_session
import time
import importlib
import json
from json.decoder import JSONDecodeError
from psp.dataset import *
from psp._globals import model_output, OUTPUT_DIR, current_datetime
from psp.plot_model import *
from psp.evaluate import *
from psp.utils import *
import warnings
#module-level runtime setup: silence noisy warnings/logs and create the
#TensorFlow session that is registered with Keras below (main() closes it)

#ignore FutureWarnings whose message starts with "Passing"
warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)
#NOTE(review): this is set after `import tensorflow` at the top of the file, so it
#may not affect anything logged during the import itself - confirm
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #reduce TF log output to only include Errors
### TensorFlow session parameters and configuration ###
#start from a clean graph and a clean Keras backend state
tf.compat.v1.reset_default_graph()
tf.keras.backend.clear_session() # For easy reset of notebook state.
config_proto = tf.compat.v1.ConfigProto()
config_proto.allow_soft_placement = True
#RewriterConfig.OFF is used below to switch off the arithmetic optimization graph rewrite
off = rewriter_config_pb2.RewriterConfig.OFF
config_proto.gpu_options.allow_growth = True
config_proto.graph_options.rewrite_options.arithmetic_optimization = off
#set tensorflow GPUOptions so TF doesn't overload GPU if present
# config_proto.gpu_options(per_process_gpu_memory_fraction=0.333)
#create the session from the config above
session = tf.compat.v1.Session(config=config_proto)
# tf.Session(config=tf.compat.v1.ConfigProto(log_device_placement=True))
#hand the session to the Keras backend
set_session(session)
#get model filenames from models directory
def remove_py(x):
    """
    Description:
        Strip the final file extension (e.g. '.py') from a filename or path.
    Args:
        :x (str): filename or path.
    Returns:
        :stem (str): x without its last extension; x unchanged if it has none.
    """
    stem, _extension = os.path.splitext(x)
    return stem
def _psp_model_names(folder):
    """
    Description:
        List the files directly inside folder whose name starts with 'psp',
        with their file extension removed.
    Args:
        :folder (str): directory to scan (not recursive).
    Returns:
        :names (list): extension-less filenames, in os.listdir order.
    """
    names = []
    for entry in os.listdir(folder):
        #sub-directories are skipped; only regular files prefixed 'psp' count as models
        if os.path.isfile(os.path.join(folder, entry)) and entry[:3] == 'psp':
            names.append(remove_py(entry))
    return names


#every model name selectable from a config file: main models first, then the
#auxiliary models, with 'dummy_model' always appended last
all_models = _psp_model_names(os.path.join('psp', 'models'))
all_models.extend(_psp_model_names(os.path.join('psp', 'models', 'auxiliary_models')))
all_models.append('dummy_model')
#main starting function for PSP code pipeline
def main(args):
    """
    Description:
        Main function for training, evaluating and plotting PSP models. Reads
        the JSON config named by args.config from the 'config' folder, builds
        the requested model, trains it on the CullPDB dataset and stores the
        trained model, its history, plots and evaluation output in a new
        folder named <model>_<current_datetime> under OUTPUT_DIR.
    Args:
        :args (argparse.Namespace): parsed input arguments; only args.config
            (name of the JSON config file, file extension optional) is read.
    Returns:
        None
    Raises:
        OSError: if the config file does not exist.
        JSONDecodeError: if the config file is not valid JSON.
        ValueError: if the model named in the config is not in all_models.
    """
    #append config filepath to user input config file
    #strip file extension (if exists) from input filename and append .json to it
    config_file = os.path.join("config", os.path.splitext(args.config)[0] + '.json')

    #open JSON config file in config folder
    params = _load_config(config_file)

    #parse input arguments from json config file
    run_params = params["parameters"][0]
    model_params = params["model_parameters"][0]
    training_data = run_params["training_data"]
    filtered = run_params["filtered"]
    batch_size = int(run_params["batch_size"])
    epochs = int(run_params["epochs"])
    learning_rate = float(model_params["optimizer"]["learning_rate"])
    logs_path = str(run_params["logs_path"])
    cuda = run_params["cuda"]
    test_dataset = str(run_params["test_dataset"])
    model_ = str(run_params["model"])
    tf_version = tf.__version__
    callbacks = model_params["callbacks"]
    lr_scheduler = str(callbacks["lrScheduler"]["scheduler"])
    save_h5 = run_params["save_h5"]

    #set model output dict to values in config
    model_output["Config"] = os.path.basename(config_file)
    model_output["Model"] = model_
    model_output["Training Dataset Type"] = training_data
    model_output["Filtered?"] = filtered
    model_output["Test Dataset"] = test_dataset
    model_output["Number of epochs"] = epochs
    model_output["Batch size"] = batch_size
    model_output["Tensorflow Version"] = tf_version
    model_output["Cuda"] = cuda
    model_output["LR Scheduler"] = lr_scheduler

    print("\n###################################################################")
    print("Running model locally with parameters...\n")
    print("Configuration File: {}".format(config_file))
    print("Training Data: {} (filtered: {})".format(training_data, filtered))
    print("Test Dataset: {}".format(test_dataset))
    print("Model: {}".format(model_))
    print("Batch Size: {}".format(batch_size))
    print("Epochs: {}".format(epochs))
    print("Learning Rate: {}".format(learning_rate))
    print("Logs Path: {}".format(logs_path))
    print("Cuda: {}".format(cuda))
    print('Callbacks:')
    print(' TensorBoard: {}'.format(callbacks["tensorboard"]["tensorboard"]))
    print(' Early Stopping: {}'.format(callbacks["earlyStopping"]["earlyStopping"]))
    print(' Model Checkpoint: {}'.format(callbacks["modelCheckpoint"]["modelCheckpoint"]))
    print(' Learning Rate Scheduler: {}'.format(callbacks["lrScheduler"]["lrScheduler"]))
    if callbacks["lrScheduler"]["lrScheduler"]:
        print(' Scheduler: {}'.format(lr_scheduler))
    #each line reports its own callback's flag (these two used to echo the TensorBoard flag)
    print(' CSV Logger: {}'.format(callbacks["csv_logger"]["csv_logger"]))
    print(' Reduce LR on Plateau: {}'.format(callbacks["reduceLROnPlateau"]["reduceLROnPlateau"]))
    print("###################################################################\n")

    #verify model specified in config file exists in available models
    if model_ not in all_models:
        raise ValueError('Model {} must be in available models: \n {}.'.format(
            model_, all_models))

    #load cullPDB training dataset
    cullpdb = CullPDB(type=training_data, filtered=filtered)

    #import model module from models or auxiliary models folder; only the
    #models named here live directly in psp.models
    if model_ in ("psp_dcblstm_model", "psp_dculstm_model", "dummy_model"):
        mod = importlib.import_module("psp.models." + model_)
    else:
        mod = importlib.import_module("psp.models.auxiliary_models." + model_)

    #build model
    model = mod.build_model(model_params)

    #create saved_models directory where trained models will be stored
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    #create folder where all model assets and artifacts will be stored after training;
    #deliberately no exist_ok here so an existing run folder is never silently reused
    model_folder_path = os.path.join(os.path.join(os.getcwd(), OUTPUT_DIR),
        model_ + '_' + current_datetime)
    os.makedirs(model_folder_path)

    #create logs path directory where TensorBoard logs will be stored
    os.makedirs(os.path.join(model_folder_path, logs_path), exist_ok=True)

    #create checkpoints dir where model checkpoints will be saved
    os.makedirs(os.path.join(model_folder_path, 'checkpoints'), exist_ok=True)

    #initialise Tensorflow callbacks enabled in the config
    all_callbacks = _build_callbacks(callbacks, model_folder_path, logs_path)

    #start time func to measure the training time
    start = time.time()

    #fit model
    print('Fitting model...')
    if cuda:
        #if training on GPU, pin the fit to the first GPU
        with tf.device('/gpu:0'):
            history = _fit_model(model, cullpdb, epochs, batch_size, all_callbacks)
    else:
        #no explicit device requested (originally noted as: training on CPU, the default)
        history = _fit_model(model, cullpdb, epochs, batch_size, all_callbacks)

    #calculate elapsed training time
    elapsed = (time.time() - start)
    print('Elapsed Training Time: {}'.format(elapsed))

    #append training time to output file
    model_output['Training Time'] = elapsed

    #save trained model in either hdf5 or SavedModel format
    if save_h5:
        model.save(os.path.join(model_folder_path, 'model.h5'))
    else:
        model.save(os.path.join(model_folder_path, 'model'))

    #save model history pickle
    save_history(history, os.path.join(model_folder_path, 'history.pckl'))

    #plot history
    plot_history(history.history, model_folder_path, show_histograms = True,
        show_boxplots = True, show_kde = True, filter_outliers = True, save = True)

    #visualise Keras model and all its layers, store in png in output folder
    visualise_model(model, model_folder_path)

    #evaluating model
    evaluate_cullpdb(model, cullpdb)
    evaluate_model(model, test_dataset=test_dataset)

    #getting output results from model into csv (return value is not needed here)
    get_model_output(model_folder_path)

    #save model architecture
    with open(os.path.join(model_folder_path, "model_architecture.json"), "w") as model_arch:
        model_arch.write(model.to_json(indent=3))

    print('Model training files exported to local path: {} '.format(model_folder_path))

    #close tensorflow session
    session.close()


def _load_config(config_file):
    """
    Description:
        Read and parse the JSON configuration file.
    Args:
        :config_file (str): path to the JSON config file.
    Returns:
        :params: parsed contents of the config file.
    Raises:
        OSError: if no file exists at config_file.
        JSONDecodeError: if the file does not contain valid JSON.
    """
    if not os.path.isfile(config_file):
        raise OSError('JSON config file not found at path: {}.'.format(config_file))
    try:
        with open(config_file) as f:
            return json.load(f)
    except JSONDecodeError:
        print('Error getting config JSON file: {}.'.format(config_file))
        #re-raise: carrying on without a parsed config only fails later with a
        #confusing NameError
        raise


def _callback_options(callbacks, name):
    """
    Description:
        Get the keyword arguments for one callback section of the config. Each
        section stores its on/off flag under a key equal to the section name.
    Args:
        :callbacks (dict): the "callbacks" section of the model parameters.
        :name (str): section name, e.g. "tensorboard".
    Returns:
        :options (dict or None): the section without its flag key if the flag is
            truthy, otherwise None. The config itself is left unmodified.
    """
    settings = callbacks[name]
    if not settings[name]:
        return None
    return {key: value for key, value in settings.items() if key != name}


def _build_callbacks(callbacks, model_folder_path, logs_path):
    """
    Description:
        Create the Keras callbacks that are switched on in the config.
    Args:
        :callbacks (dict): the "callbacks" section of the model parameters.
        :model_folder_path (str): output folder of the current training run.
        :logs_path (str): TensorBoard log folder, relative to model_folder_path.
    Returns:
        :all_callbacks (list): instantiated callbacks to pass to model.fit.
    """
    all_callbacks = []

    options = _callback_options(callbacks, "tensorboard")
    if options is not None:
        all_callbacks.append(TensorBoard(
            log_dir=os.path.join(model_folder_path, logs_path), **options))

    options = _callback_options(callbacks, "earlyStopping")
    if options is not None:
        all_callbacks.append(EarlyStopping(**options))

    options = _callback_options(callbacks, "modelCheckpoint")
    if options is not None:
        all_callbacks.append(ModelCheckpoint(
            filepath=os.path.join(model_folder_path, 'checkpoints',
                'model_' + current_datetime + '.h5'), **options))

    options = _callback_options(callbacks, "csv_logger")
    if options is not None:
        all_callbacks.append(CSVLogger(
            filename=os.path.join(model_folder_path, 'training.log'), **options))

    options = _callback_options(callbacks, "reduceLROnPlateau")
    if options is not None:
        all_callbacks.append(ReduceLROnPlateau(**options))

    #get LR Scheduler callback to use from parameter in config file, but only
    #when the scheduler is switched on (matches what the run summary reports)
    if callbacks["lrScheduler"]["lrScheduler"]:
        #lowercase the name and remove any whitespace or '-' from it
        scheduler_name = str(callbacks["lrScheduler"]["scheduler"])
        scheduler_name = scheduler_name.lower().strip().replace(" ", "").replace("-", "")
        #"exceptionaldecay" is a historical misspelling, still accepted so
        #existing config files keep working
        if scheduler_name in ("exponentialdecay", "exponential", "exceptionaldecay"):
            all_callbacks.append(LearningRateScheduler(ExponentialDecay()))
        elif scheduler_name in ("timebaseddecay", "timebased"):
            all_callbacks.append(LearningRateScheduler(TimedBased()))
        elif scheduler_name in ("stepdecay", "step"):
            all_callbacks.append(LearningRateScheduler(StepDecay()))

    return all_callbacks


def _fit_model(model, cullpdb, epochs, batch_size, all_callbacks):
    """
    Description:
        Train the model on the CullPDB training split, validating on its
        validation split.
    Args:
        :model: model returned by the model module's build_model.
        :cullpdb: dataset object exposing train_hot, trainpssm, trainlabel,
            val_hot, valpssm and vallabel.
        :epochs (int): number of training epochs.
        :batch_size (int): training batch size.
        :all_callbacks (list): Keras callbacks to run during training.
    Returns:
        :history: object returned by model.fit.
    """
    return model.fit(
        {'main_input': cullpdb.train_hot, 'aux_input': cullpdb.trainpssm},
        {'main_output': cullpdb.trainlabel},
        validation_data=(
            {'main_input': cullpdb.val_hot, 'aux_input': cullpdb.valpssm},
            {'main_output': cullpdb.vallabel}),
        epochs=epochs, batch_size=batch_size, verbose=2,
        callbacks=all_callbacks, shuffle=True)
if __name__ == "__main__":
    #############################################################
    ###                 PSP Input Arguments                   ###
    #############################################################
    #the config file name is the only (and a required) command-line option
    cli = argparse.ArgumentParser(
        description='Protein Secondary Structure Prediction.'
    )
    cli.add_argument(
        '-config', '--config',
        type=str,
        required=True,
        help='File path to config json file.',
    )

    #parse the input args and hand them straight to the pipeline
    main(cli.parse_args())