-
Notifications
You must be signed in to change notification settings - Fork 2
/
neuralNetworkTrainer.py
65 lines (50 loc) · 1.77 KB
/
neuralNetworkTrainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from uwimg import *
import time
import sys
def softmax_model(inputs, outputs):
    """Build a one-layer model: a single softmax layer mapping inputs -> outputs."""
    softmax_layer = make_layer(inputs, outputs, SOFTMAX)
    return make_model([softmax_layer])
def neural_net(inputs, outputs):
    """Build a fully connected net: inputs -> 32 (RELU) -> 16 (RELU) -> outputs (SOFTMAX)."""
    print(inputs)  # debug trace of the input width, kept from the original
    hidden_one = make_layer(inputs, 32, RELU)
    hidden_two = make_layer(32, 16, RELU)
    output_layer = make_layer(16, outputs, SOFTMAX)
    return make_model([hidden_one, hidden_two, output_layer])
def save_to_file(filepath, process_to_run):
    """Run *process_to_run* with stdout redirected into *filepath*.

    Args:
        filepath: Path of the file that captures everything printed while
            the process runs.
        process_to_run: A zero-argument callable executed while stdout is
            redirected.  For backward compatibility a non-callable value is
            accepted and ignored (the original code never invoked it).
    """
    orig_stdout = sys.stdout
    try:
        # BUG FIX: the original opened the literal name 'filepath' (a string)
        # instead of the argument's value.
        with open(filepath, 'w') as f:
            sys.stdout = f
            # BUG FIX: the original line `process_to_run` was a bare name
            # reference and never executed anything.
            if callable(process_to_run):
                process_to_run()
    finally:
        # Restore stdout even if the wrapped process raises; the `with`
        # statement already closes the file (the original's f.close() after
        # the with-block was redundant).
        sys.stdout = orig_stdout
def neural_on_mnist_dataset():
    """Train a ReLU neural net on MNIST and print timing and accuracy stats."""
    start_time = time.time()
    print("loading data...")
    train_file_path = "mnist.train"
    labels_path = "mnist/mnist.labels"
    test_file_path = "mnist.test"
    # The underlying C library expects char* paths, hence the explicit
    # UTF-8 encoding through c_char_p.
    train = load_classification_data(c_char_p(train_file_path.encode('utf-8')), c_char_p(labels_path.encode('utf-8')), 1)
    test = load_classification_data(c_char_p(test_file_path.encode('utf-8')), c_char_p(labels_path.encode('utf-8')), 1)
    print("Loading Data done")
    print()
    print("training model...")
    # Hyperparameters: minibatch SGD with momentum and weight decay.
    batch = 128
    iters = 1000
    rate = .01
    momentum = .9
    decay = .0005
    m = neural_net(train.X.cols, train.y.cols)
    print(train_model(m, train, batch, iters, rate, momentum, decay))
    print("Training done")
    training_time = time.time()
    # Fixed typo in the original message ("Training Ttime").
    print('Training time took {}'.format(training_time - start_time))
    testing_time_start = time.time()
    print("evaluating model...")
    # BUG FIX: the original passed the accuracy as a second print() argument
    # instead of %-formatting it, so the literal "%f" was printed.
    print("training accuracy: %f" % accuracy_model(m, train))
    print("test accuracy: %f" % accuracy_model(m, test))
    testing_time_ends = time.time()
    print('Testing took {}'.format(testing_time_ends - testing_time_start))
    print('Total Time Taken = {}'.format(testing_time_ends - start_time))
if __name__ == "__main__":
    # BUG FIX: pass the function object, not its result. The original called
    # neural_on_mnist_dataset() at argument-evaluation time — before
    # save_to_file redirected stdout — so none of its output was captured.
    save_to_file("results/mnist_Relu_Activation", neural_on_mnist_dataset)