forked from SakhriHoussem/Image-Classification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_vgg19_distibuted.py
89 lines (77 loc) · 3.82 KB
/
train_vgg19_distibuted.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
Simple tester for the Tensorflow distributed
"""
from os import environ
import numpy as np
import tensorflow as tf
from Networking import ClusterGen
from dataSetGenerator import append
from dataSetGenerator import loadClasses
from dataSetGenerator import picShow
from vgg19 import vgg19_trainable as vgg19
# Training data: the picture array and its matching target array are loaded
# in that order, followed by the list of class names.
batch, labels = map(np.load, ("DataSets/UCMerced_LandUse_dataTrain.npy",
                              "DataSets/UCMerced_LandUse_labelsTrain.npy"))
classes = loadClasses("DataSets/UCMerced_LandUse.txt")
classes_num = len(classes)
# Picture "rib" (side length): the model placeholder is built as
# [None, rib, rib, 3], so dimension 1 of the picture array is used.
rib = np.shape(batch)[1]
# Cluster membership, keyed on Windows host names (COMPUTERNAME).
workers = ['DESKTOP-07HFBQN', 'FOUZI-PC']
pss = ['DELL-MINI']

# Resolve this machine's role once.  .get() avoids a bare KeyError where
# COMPUTERNAME is not set, and an unlisted host is reported explicitly here:
# passing job_name=None to tf.train.Server on a multi-job cluster would
# otherwise fail with an unrelated-looking error.
host = environ.get('COMPUTERNAME')
if host in workers:
    job, index = 'worker', workers.index(host)
elif host in pss:
    job, index = 'ps', pss.index(host)
else:
    raise SystemExit("error JOB: host {!r} is not listed in workers {} or "
                     "pss {}".format(host, workers, pss))

cluster = tf.train.ClusterSpec(ClusterGen(workers, pss))
server = tf.train.Server(cluster, job_name=job, task_index=index)
# Build the model graph shared by every process in the cluster.
# replica_device_setter places all variables (VGG weights, global step) on the
# parameter-server job so the workers share them, and every other op on this
# process's own task.  A fixed "/job:ps/task:<index>" would point worker 1 at
# a ps task that does not exist (there is only one ps).
with tf.device(tf.train.replica_device_setter(
        worker_device="/job:{}/task:{}".format(job, index),
        cluster=cluster)):
    images = tf.placeholder(tf.float32, [None, rib, rib, 3])
    true_out = tf.placeholder(tf.float32, [None, classes_num])
    train_mode = tf.placeholder(tf.bool)
    weights_path = 'Weights/VGG19_{}C.npy'.format(classes_num)
    try:
        vgg = vgg19.Vgg19(weights_path, classes_num)
    except Exception as err:
        # Best-effort: fall back to freshly initialised weights, but say so
        # (and let KeyboardInterrupt/SystemExit propagate).
        print('[WARNING] could not load {} ({}); starting without '
              'pretrained weights'.format(weights_path, err))
        vgg = vgg19.Vgg19(None, classes_num)
    vgg.build(images, train_mode)
    global_step = tf.train.get_or_create_global_step()
    # Several workers bump this shared counter concurrently; a locked
    # assign_add avoids the lost updates of a read-then-assign.  Evaluates to
    # the incremented value.
    inc_global_step = tf.assign_add(global_step, 1, use_locking=True)
if job == 'worker':
    # Loss, optimizer and metrics run on this worker; the variables they read
    # and update belong to the shared model graph.
    with tf.device("/job:worker/task:{}".format(index)):
        # Sum of squared differences between predicted probabilities and the
        # targets (presumably one-hot label vectors -- confirm).
        cost = tf.reduce_sum((vgg.prob - true_out) ** 2)
        train = tf.train.GradientDescentOptimizer(0.0001).minimize(cost)
        # Per-sample comparison of predicted vs. true class.  Axis 1 is the
        # class axis (true_out is [None, classes_num]); cost is a scalar and
        # has no class axis to take an argmax over.
        correct_prediction = tf.equal(tf.argmax(vgg.prob, 1),
                                      tf.argmax(true_out, 1))
        acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    batch_size = 5
    sample_count = len(batch)
    # The shared global step serves as a cluster-wide epoch counter: each
    # worker claims the next epoch number by incrementing it and stops once
    # the claimed number exceeds last_epoch.  StopAtStepHook is not used
    # because, once it requests a stop, MonitoredSession raises on every later
    # run(), aborting mid-epoch and skipping the final save.
    last_epoch = 2

    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(index == 0)) as sess:
        while not sess.should_stop():
            epoch = sess.run(inc_global_step)
            if epoch > last_epoch:
                break
            costs = []
            accs = []
            print("******************* ", epoch, " *******************")
            order = np.random.permutation(sample_count)
            # Trailing samples that do not fill a whole batch are skipped.
            for i in range(sample_count // batch_size):
                min_batch = order[i * batch_size:(i + 1) * batch_size]
                cur_cost, _, cur_acc = sess.run(
                    [cost, train, acc],
                    feed_dict={images: batch[min_batch],
                               true_out: labels[min_batch],
                               train_mode: True})
                print("Iteration %d loss:\n%s" % (i, cur_cost))
                costs.append(str(cur_cost) + '\n')
                accs.append(str(cur_acc) + '\n')
            # Persist this epoch's per-iteration cost and accuracy values.
            append(costs, 'Data/cost19_{}C_D'.format(classes_num))
            append(accs, 'Data/acc19_{}C_D'.format(classes_num))

        # A stopped MonitoredSession rejects further run() calls, so only save
        # and test when training ended by running out of epochs.
        if not sess.should_stop():
            vgg.save_npy(sess, 'Weights/VGG19_{}C_D.npy'.format(classes_num))
            # test classification on the first 10 training pictures
            prob = sess.run(vgg.prob, feed_dict={images: batch[:10],
                                                 train_mode: False})
            picShow(batch[:10], labels[:10], classes, None, prob)
elif job == 'ps':
    # Parameter-server process: block here serving the cluster.
    server.join()
else:
    print("error JOB")