forked from SakhriHoussem/Image-Classification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_vgg16.py
99 lines (82 loc) · 4.15 KB
/
train_vgg16.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
Train the trainable VGG16 network (vgg16_trainable) on a prepared dataset and save its weights.
"""
from os import environ
import numpy as np
import tensorflow as tf
from dataSetGenerator import picShow,append
from vgg16 import vgg16_trainable as vgg16
import argparse
# Command-line interface: the dataset name is mandatory, while the batch
# size and the number of epochs fall back to defaults.
parser = argparse.ArgumentParser(
    prog="Train vgg16",
    description="Simple tester for the vgg16_trainable",
)
parser.add_argument(
    '--dataset',
    metavar='dataset',
    type=str,
    required=True,
    help='DataSet Name',
)
parser.add_argument(
    '--batch',
    metavar='batch',
    type=int,
    default=10,
    help='batch size ',
)
parser.add_argument(
    '--epochs',
    metavar='epochs',
    type=int,
    default=30,
    help='number of epoch to train the network',
)
args = parser.parse_args()

# Expose the parsed options under the names the rest of the script uses.
classes_name, batch_size, epochs = args.dataset, args.batch, args.epochs
# batch_size = 10
# epochs = 30
# classes_name = "RSSCN7"
# classes_name = "UCMerced_LandUse_DU"
# classes_name = "SIRI-WHU"
# Load the class names, the training images and the training labels that
# were saved for the selected dataset (same files, same order as before).
classes, batch, labels = (
    np.load("DataSets/{}_{}.npy".format(classes_name, part))
    for part in ("classes", "dataTrain", "labelsTrain")
)

# Number of output classes of the network.
classes_num = len(classes)

# Picture side length; the input placeholder below uses it for both the
# height and the width.
rib = batch.shape[1]
# with tf.device('/device:GPU:0'):
# with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess:
# Build the VGG16 graph on the CPU, train it with mini-batch gradient
# descent, and periodically save the loss/accuracy history and the weights.
with tf.device('/cpu:0'):
    # NUMBER_OF_PROCESSORS is only defined on Windows. Falling back to 0
    # lets TensorFlow choose the intra-op thread count itself instead of
    # crashing with a KeyError on other platforms.
    intra_threads = int(environ.get('NUMBER_OF_PROCESSORS', 0))
    session_config = tf.ConfigProto(intra_op_parallelism_threads=intra_threads)

    with tf.Session(config=session_config) as sess:
        images = tf.placeholder(tf.float32, [None, rib, rib, 3])
        true_out = tf.placeholder(tf.float32, [None, classes_num])
        train_mode = tf.placeholder(tf.bool)

        weights_path = 'Weights/VGG16_{}.npy'.format(classes_name)
        try:
            vgg = vgg16.Vgg16(weights_path, classes_num)
        except Exception:
            # Best effort: if the saved weights cannot be loaded, start from
            # a model without them. `Exception` (rather than a bare except)
            # keeps Ctrl-C and SystemExit working.
            print('{} Not Exist'.format(weights_path))
            vgg = vgg16.Vgg16(None, classes_num)
        vgg.build(images, train_mode)

        # report how many variables the model uses
        print('number of variables used:', vgg.get_var_count())
        sess.run(tf.global_variables_initializer())

        # test classification before training
        prob = sess.run(vgg.prob, feed_dict={images: batch[:10], train_mode: False})
        picShow(batch[:10], labels[:10], classes, None, prob, True)

        # squared-error loss minimised with plain gradient descent
        cost = tf.reduce_sum((vgg.prob - true_out) ** 2)
        train = tf.train.GradientDescentOptimizer(0.0001).minimize(cost)

        # Accuracy must be computed from the graph output `vgg.prob` (not the
        # NumPy array `prob` fetched above) and the argmax has to run over
        # the class axis (1), not over the batch axis.
        correct_prediction = tf.equal(tf.argmax(vgg.prob, 1), tf.argmax(true_out, 1))
        acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        batche_num = batch.shape[0]
        # A trailing partial batch is dropped, as before.
        steps_per_epoch = batche_num // batch_size
        costs = []
        accs = []

        def save_progress():
            """Write the loss/accuracy history and the current weights to disk."""
            # NOTE(review): `costs` and `accs` are never cleared between
            # saves; if `append` writes the whole list on each call, earlier
            # values end up in the files more than once - confirm against
            # dataSetGenerator.append.
            append(costs, 'Data/COST16_{}.txt'.format(classes_name))
            append(accs, 'Data/ACC16_{}.txt'.format(classes_name))
            vgg.save_npy(sess, weights_path)

        for epoch in range(epochs):
            # visit the samples in a fresh random order every epoch
            indice = np.random.permutation(batche_num)
            for step in range(steps_per_epoch):
                min_batch = indice[step * batch_size:(step + 1) * batch_size]
                cur_cost, _, cur_acc = sess.run(
                    [cost, train, acc],
                    feed_dict={
                        images: batch[min_batch],
                        true_out: labels[min_batch],
                        train_mode: True,
                    },
                )
                print("Iteration :{} Batch :{} loss :{}".format(epoch, step, cur_cost))
                accs.append(cur_acc)
                costs.append(cur_cost)
                # checkpoint every 100 mini-batches within an epoch
                if (step + 1) % 100 == 0:
                    save_progress()

        # final save once training has finished
        save_progress()

        # test classification again on the same 10 pictures after training
        prob = sess.run(vgg.prob, feed_dict={images: batch[:10], train_mode: False})
        picShow(batch[:10], labels[:10], classes, None, prob)
# import subprocess
# subprocess.call(["shutdown", "/s"])