From 418b7e781ccd734f0aad9948d3af1c50c2fe4797 Mon Sep 17 00:00:00 2001 From: prickly-u Date: Tue, 25 Feb 2020 15:30:05 +0300 Subject: [PATCH 1/2] Pass preprocess_image function to generators in the evaluate.py file --- keras_retinanet/bin/evaluate.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/keras_retinanet/bin/evaluate.py b/keras_retinanet/bin/evaluate.py index 53695aea9..da093dcb9 100755 --- a/keras_retinanet/bin/evaluate.py +++ b/keras_retinanet/bin/evaluate.py @@ -28,6 +28,7 @@ from .. import models from ..preprocessing.csv_generator import CSVGenerator from ..preprocessing.pascal_voc import PascalVocGenerator +from ..utils.anchors import make_shapes_callback from ..utils.config import read_config_file, parse_anchor_parameters from ..utils.eval import evaluate from ..utils.gpu import setup_gpu @@ -35,9 +36,13 @@ from ..utils.tf_version import check_tf_version -def create_generator(args): +def create_generator(args, preprocess_image): """ Create generators for evaluation. """ + common_args = { + 'preprocess_image': preprocess_image, + } + if args.dataset_type == 'coco': # import here to prevent unnecessary dependency on cocoapi from ..preprocessing.coco import CocoGenerator @@ -49,6 +54,7 @@ def create_generator(args): image_max_side=args.image_max_side, config=args.config, shuffle_groups=False, + **common_args ) elif args.dataset_type == 'pascal': validation_generator = PascalVocGenerator( @@ -59,6 +65,7 @@ def create_generator(args): image_max_side=args.image_max_side, config=args.config, shuffle_groups=False, + **common_args ) elif args.dataset_type == 'csv': validation_generator = CSVGenerator( @@ -68,6 +75,7 @@ def create_generator(args): image_max_side=args.image_max_side, config=args.config, shuffle_groups=False, + **common_args ) else: raise ValueError('Invalid data type received: {}'.format(args.dataset_type)) @@ -109,7 +117,7 @@ def parse_args(args): def main(args=None): - # parse arguments + # parse argumentsbin if args is None: args = sys.argv[1:] args = parse_args(args) @@ -131,7 +139,8 @@ def main(args=None): args.config = read_config_file(args.config) # create the generator - generator = create_generator(args) + backbone = models.backbone(args.backbone) + generator = create_generator(args, backbone.preprocess_image) # optionally load anchor parameters anchor_params = None @@ -141,6 +150,7 @@ def main(args=None): # load the model print('Loading model, this may take a second...') model = models.load_model(args.model, backbone_name=args.backbone) + generator.compute_shapes = make_shapes_callback(model) # optionally convert the model if args.convert_model: From 5baaa7c4c69ae286f773318014559072502d739c Mon Sep 17 00:00:00 2001 From: prickly-u Date: Tue, 25 Feb 2020 15:39:27 +0300 Subject: [PATCH 2/2] Fixed typo --- keras_retinanet/bin/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_retinanet/bin/evaluate.py b/keras_retinanet/bin/evaluate.py index da093dcb9..90c095c75 100755 --- a/keras_retinanet/bin/evaluate.py +++ b/keras_retinanet/bin/evaluate.py @@ -117,7 +117,7 @@ def parse_args(args): def main(args=None): - # parse argumentsbin + # parse arguments if args is None: args = sys.argv[1:] args = parse_args(args)