-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_fcn.py
105 lines (84 loc) · 3.64 KB
/
test_fcn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 21 16:41:52 2018
@author: LiHongWang
"""
import skimage.io as io
import skimage.transform as transform
import numpy as np
import tensorflow as tf
from model import fcn_resnet_v2
from data import input_data
import cv2
slim = tf.contrib.slim  # TF-Slim alias (TensorFlow 1.x contrib API)
num_classes=2  # binary segmentation: background vs. road/lane
def test():
    """Run a single-image inference pass through FCN-ResNet101.

    Builds the TF-Slim input pipeline and the FCN graph, restores the
    latest checkpoint from ``logs_train_dir``, loads one KITTI road test
    image from disk, preprocesses it the same way as training (resize to
    224x224 + per-image standardization), and runs it through the network.

    Returns:
        tuple: ``(prediction, img)`` where ``prediction`` is the raw logits
        batch of shape ``[batchSize, 224, 224, num_classes]`` and ``img`` is
        the logits of the last image in the batch.  Returns ``None`` when no
        checkpoint could be restored or the queue hit its epoch limit.
    """
    logs_train_dir = '/home/Public/seg_project/slim_seg/um/lane/u_arg/'
    batchSize = 1
    data_sources = 'D:/dataSet/kitti/road/uu_val30.tfrecord'
    num_samples = 30

    # Input pipeline from the tfrecord.  The dequeued batch is not consumed
    # below (a hand-loaded image is fed instead), but the queue is kept so
    # the queue runners started later have a graph to drive.
    samples = input_data.get_images_labels(data_sources, num_classes,
                                           num_samples, batch_size=batchSize)
    batch_queue = slim.prefetch_queue.prefetch_queue(samples, capacity=128)
    tra_batch = batch_queue.dequeue()

    # Inference graph: fixed-size input placeholder, no training ops.
    x = tf.placeholder(tf.float32, shape=[batchSize, 224, 224, 3])
    _, logit = fcn_resnet_v2.fcn_res101(x, num_classes, is_training=False)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Model weights come from the checkpoint; only local variables
        # (queue/epoch counters) need explicit initialization.
        sess.run(tf.local_variables_initializer())
        print("Reading checkpoints...")
        ckpt = tf.train.get_checkpoint_state(logs_train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading success')
        else:
            # Bug fix: the original fell through here and ran inference on
            # uninitialized weights; abort instead.
            print('No checkpoint file found')
            return None

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        prediction = None
        img = None
        try:
            for step in range(1):
                if step % 100 == 0:
                    print("step: %d done !" % step)
                if coord.should_stop():
                    break
                # Load one test image and preprocess it like training data.
                # NOTE(review): cv2.imread returns BGR — assumed to match
                # the training pipeline's channel order; confirm.
                image = cv2.imread('D:/dataSet/kitti/road/data_road/testing/uu/uu_000000.png')
                image = cv2.resize(image, (224, 224))
                image = tf.cast(image, tf.float32)
                image = tf.image.per_image_standardization(image)
                image = sess.run(image)
                image = np.reshape(image, [batchSize, 224, 224, 3])
                prediction = sess.run(logit, feed_dict={x: image})
                for i in range(batchSize):
                    img = prediction[i]
            return prediction, img
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            # Always shut the queue-runner threads down cleanly.
            coord.request_stop()
            coord.join(threads)
if __name__ == '__main__':
    # Script entry point: run one inference pass and keep the outputs in
    # module scope for interactive inspection.
    outputs = test()
    pred, img = outputs