-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_pose_net_model.py
36 lines (24 loc) · 1.04 KB
/
train_pose_net_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import argparse
import tensorflow as tf
from core.pose_net_model import PoseNetModel
from core.simple_trainer import SimpleTrainer
from core.simple_predictor import SimplePredictor
from core.batcher import Batcher
def run(images_dir='./data/images',
        train_csv='./data/train.csv',
        valid_csv='./data/valid.csv',
        test_csv='./data/test.csv',
        batch_size=4):
    """Train a PoseNet model, then evaluate it on the test split.

    Builds train/validation ``Batcher`` objects, constructs and builds a
    ``PoseNetModel``, trains it with ``SimpleTrainer`` and finally prints the
    two mean errors returned by ``SimplePredictor.test()``.

    Every argument defaults to the value that used to be hard-coded, so
    ``run()`` with no arguments behaves exactly as before.

    Args:
        images_dir: Directory passed to every ``Batcher`` as ``path_to_data``.
        train_csv: CSV file (``path_to_csv``) for the training batcher.
        valid_csv: CSV file (``path_to_csv``) for the validation batcher.
        test_csv: CSV file (``path_to_csv``) for the test batcher.
        batch_size: Batch size shared by all three batchers.

    Returns:
        The ``(mean_pos, mean_qua)`` pair returned by ``SimplePredictor.test()``.
    """
    train_batcher = Batcher(path_to_data=images_dir, path_to_csv=train_csv,
                            batch_size=batch_size)
    valid_batcher = Batcher(path_to_data=images_dir, path_to_csv=valid_csv,
                            batch_size=batch_size)

    pn_model = PoseNetModel()
    pn_model.build_model()

    # The session used to be created and never closed, leaking its resources.
    # The context manager closes it once evaluation is done (and makes it the
    # default session while inside the block).
    with tf.Session() as sess:
        # NOTE(review): the meaning of the positional 100 and 10 is defined by
        # SimpleTrainer, which is not visible here (presumably epoch / logging
        # counts) — confirm before exposing them as named parameters.
        trainer = SimpleTrainer(pn_model, train_batcher, valid_batcher, sess, 100, 10)
        trainer.train()

        # The test batcher is still built only after training, as before.
        test_batcher = Batcher(path_to_data=images_dir, path_to_csv=test_csv,
                               batch_size=batch_size, batcher_type='test')
        predictor = SimplePredictor(pn_model, test_batcher, sess)
        print('Start testing...')
        mean_pos, mean_qua = predictor.test()
        print('Mean pose error: {}, mean quaternion error: {}'.format(mean_pos, mean_qua))

    return mean_pos, mean_qua
def parse_args(argv=None):
    """Parse command-line options for the PoseNet train/test script.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``images_dir``,
        ``train_csv``, ``valid_csv``, ``test_csv`` and ``batch_size``. The
        defaults are the script's data locations (``./data/images`` and
        ``./data/{train,valid,test}.csv``) and a batch size of 4.
    """
    parser = argparse.ArgumentParser(
        description='Train a PoseNet model and evaluate it on the test split.')
    parser.add_argument('--images-dir', default='./data/images',
                        help='directory containing the images (default: %(default)s)')
    parser.add_argument('--train-csv', default='./data/train.csv',
                        help='CSV file for the training split (default: %(default)s)')
    parser.add_argument('--valid-csv', default='./data/valid.csv',
                        help='CSV file for the validation split (default: %(default)s)')
    parser.add_argument('--test-csv', default='./data/test.csv',
                        help='CSV file for the test split (default: %(default)s)')
    parser.add_argument('--batch-size', type=int, default=4,
                        help='batch size for every batcher (default: %(default)s)')
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Train and evaluate only when executed as a script, never on import.
    run()