From 4e08699fba5b91109c0aba7a5f7d2fe9d3e39455 Mon Sep 17 00:00:00 2001 From: Prakhar Thapak Date: Mon, 2 Apr 2018 13:18:47 +0530 Subject: [PATCH] Testing can be done on the whole set of images. Now testing can be done on the whole set of images inside the testing images directory instead of testing single image one by one. --- version2_predict.py | 104 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 version2_predict.py diff --git a/version2_predict.py b/version2_predict.py new file mode 100644 index 000000000..b47d0ae92 --- /dev/null +++ b/version2_predict.py @@ -0,0 +1,104 @@ +#! /usr/bin/env python + +import argparse +import os +import cv2 +import numpy as np +from tqdm import tqdm +from preprocessing import parse_annotation +from utils import draw_boxes +from frontend import YOLO +import json +import glob + +os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"]="0" + +argparser = argparse.ArgumentParser( + description='Train and validate YOLO_v2 model on any dataset') + +argparser.add_argument( + '-c', + '--conf', + help='path to configuration file') + +argparser.add_argument( + '-w', + '--weights', + help='path to pretrained weights') + +argparser.add_argument( + '-i', + '--input', + help='path to an image or an video (mp4 format)') + +def _main_(args): + config_path = args.conf + weights_path = args.weights + image_path = args.input + + with open(config_path) as config_buffer: + config = json.load(config_buffer) + + ############################### + # Make the model + ############################### + + yolo = YOLO(backend = config['model']['backend'], + input_size = config['model']['input_size'], + labels = config['model']['labels'], + max_box_per_image = config['model']['max_box_per_image'], + anchors = config['model']['anchors']) + + ############################### + # Load trained weights + ############################### + + yolo.load_weights(weights_path) + + ############################### + # Predict bounding boxes + ############################### + + if image_path[-4:] == '.mp4': + video_out = image_path[:-4] + '_detected' + image_path[-4:] + video_reader = cv2.VideoCapture(image_path) + + nb_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT)) + frame_h = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT)) + frame_w = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH)) + + video_writer = cv2.VideoWriter(video_out, + cv2.VideoWriter_fourcc(*'MPEG'), + 50.0, + (frame_w, frame_h)) + + for i in tqdm(range(nb_frames)): + _, image = video_reader.read() + + boxes = yolo.predict(image) + image = draw_boxes(image, boxes, config['model']['labels']) + + video_writer.write(np.uint8(image)) + + video_reader.release() + video_writer.release() + else: + for fnamee in glob.glob(image_path+"*.jpg"): + cache=fnamee.split("/")[0] + fname=fnamee.split("/")[-1] + + + + image = cv2.imread(fnamee) + boxes = yolo.predict(image) + image = draw_boxes(image, boxes, config['model']['labels']) + + print(len(boxes), 'boxes are found') + + + cv2.imwrite(cache+"/"+fname[:-4] + '_detected' + fname[-4:], image) + +if __name__ == '__main__': + args = argparser.parse_args() + _main_(args)