diff --git a/keras_retinanet/bin/evaluate.py b/keras_retinanet/bin/evaluate.py index 53695aea9..90c095c75 100755 --- a/keras_retinanet/bin/evaluate.py +++ b/keras_retinanet/bin/evaluate.py @@ -28,6 +28,7 @@ from .. import models from ..preprocessing.csv_generator import CSVGenerator from ..preprocessing.pascal_voc import PascalVocGenerator +from ..utils.anchors import make_shapes_callback from ..utils.config import read_config_file, parse_anchor_parameters from ..utils.eval import evaluate from ..utils.gpu import setup_gpu @@ -35,9 +36,13 @@ from ..utils.tf_version import check_tf_version -def create_generator(args): +def create_generator(args, preprocess_image): """ Create generators for evaluation. """ + common_args = { + 'preprocess_image': preprocess_image, + } + if args.dataset_type == 'coco': # import here to prevent unnecessary dependency on cocoapi from ..preprocessing.coco import CocoGenerator @@ -49,6 +54,7 @@ def create_generator(args): image_max_side=args.image_max_side, config=args.config, shuffle_groups=False, + **common_args ) elif args.dataset_type == 'pascal': validation_generator = PascalVocGenerator( @@ -59,6 +65,7 @@ def create_generator(args): image_max_side=args.image_max_side, config=args.config, shuffle_groups=False, + **common_args ) elif args.dataset_type == 'csv': validation_generator = CSVGenerator( @@ -68,6 +75,7 @@ def create_generator(args): image_max_side=args.image_max_side, config=args.config, shuffle_groups=False, + **common_args ) else: raise ValueError('Invalid data type received: {}'.format(args.dataset_type)) @@ -131,7 +139,8 @@ def main(args=None): args.config = read_config_file(args.config) # create the generator - generator = create_generator(args) + backbone = models.backbone(args.backbone) + generator = create_generator(args, backbone.preprocess_image) # optionally load anchor parameters anchor_params = None @@ -141,6 +150,7 @@ def main(args=None): # load the model print('Loading model, this may take a second...') model = models.load_model(args.model, backbone_name=args.backbone) + generator.compute_shapes = make_shapes_callback(model) # optionally convert the model if args.convert_model: