Skip to content

Commit

Permalink
Merge pull request #1290 from prickly-u/eval_fix
Browse files Browse the repository at this point in the history
Pass preprocess_image function to generators in evaluate.py
  • Loading branch information
hgaiser authored Mar 2, 2020
2 parents 460461a + 5baaa7c commit a81b313
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions keras_retinanet/bin/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,21 @@
from .. import models
from ..preprocessing.csv_generator import CSVGenerator
from ..preprocessing.pascal_voc import PascalVocGenerator
from ..utils.anchors import make_shapes_callback
from ..utils.config import read_config_file, parse_anchor_parameters
from ..utils.eval import evaluate
from ..utils.gpu import setup_gpu
from ..utils.keras_version import check_keras_version
from ..utils.tf_version import check_tf_version


def create_generator(args):
def create_generator(args, preprocess_image):
""" Create generators for evaluation.
"""
common_args = {
'preprocess_image': preprocess_image,
}

if args.dataset_type == 'coco':
# import here to prevent unnecessary dependency on cocoapi
from ..preprocessing.coco import CocoGenerator
Expand All @@ -49,6 +54,7 @@ def create_generator(args):
image_max_side=args.image_max_side,
config=args.config,
shuffle_groups=False,
**common_args
)
elif args.dataset_type == 'pascal':
validation_generator = PascalVocGenerator(
Expand All @@ -59,6 +65,7 @@ def create_generator(args):
image_max_side=args.image_max_side,
config=args.config,
shuffle_groups=False,
**common_args
)
elif args.dataset_type == 'csv':
validation_generator = CSVGenerator(
Expand All @@ -68,6 +75,7 @@ def create_generator(args):
image_max_side=args.image_max_side,
config=args.config,
shuffle_groups=False,
**common_args
)
else:
raise ValueError('Invalid data type received: {}'.format(args.dataset_type))
Expand Down Expand Up @@ -131,7 +139,8 @@ def main(args=None):
args.config = read_config_file(args.config)

# create the generator
generator = create_generator(args)
backbone = models.backbone(args.backbone)
generator = create_generator(args, backbone.preprocess_image)

# optionally load anchor parameters
anchor_params = None
Expand All @@ -141,6 +150,7 @@ def main(args=None):
# load the model
print('Loading model, this may take a second...')
model = models.load_model(args.model, backbone_name=args.backbone)
generator.compute_shapes = make_shapes_callback(model)

# optionally convert the model
if args.convert_model:
Expand Down

0 comments on commit a81b313

Please sign in to comment.